
feat: add KV caching support for Wan models #400

Open
Perseus14 wants to merge 1 commit into main from wan_kv_cache

Conversation

@Perseus14
Collaborator

@Perseus14 Perseus14 commented May 6, 2026

This Pull Request implements the KV Cache optimization for all WAN models (WAN 2.1 & 2.2, both Text-to-Video and Image-to-Video) as well as the WAN VACE model. The optimization pre-computes the Key and Value projections for the text and image embeddings once, before the denoising loop, since they remain constant across denoising steps.

Additionally, this PR introduces a hardware-aware Dynamic Image Alignment Padding optimization for next-generation TPUs, as well as fixes for minor bugs and typos.


Key Changes

1. KV Cache Optimization (WAN & VACE)

  • Attention Level (FlaxWanAttention): Modified attention_flax.py to accept cached_kv; when it is present, the key/value projections are bypassed. Added a compute_kv method to pre-project text (and image) states (a minimal sketch of this routing follows this list).
  • Transformer Level (WanModel & WanVACEModel):
    • Updated transformer_wan.py to support KV cache propagation. Added compute_kv_cache to precompute block-level cached keys/values.
    • Integrated skip_embeddings inside WanTimeTextImageEmbedding to bypass redundant embedding layers when using cached states.
    • Extended the same KV cache precomputation and propagation logic to transformer_wan_vace.py for both VACE blocks and main blocks.
  • Pipeline Level:
    • Updated forward pass helper signatures in wan_pipeline.py and wan_vace_pipeline_2_1.py to accept and propagate KV caches.
    • Updated all denoising loops (including VACE) to pre-compute the KV cache before starting the loop and reuse it at every step when use_kv_cache=True.
  • Config Defaults: Added use_kv_cache: False to all default .yml configuration files to ensure backward compatibility.
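
For orientation, here is a minimal, hypothetical sketch of the cached-KV routing described above. It is not the actual FlaxWanAttention code: the class name, shapes, and the use of jax.nn.dot_product_attention are illustrative assumptions.

```python
# Minimal sketch of cached-KV routing for cross-attention (assumed names and shapes;
# not the real FlaxWanAttention implementation).
import jax
from flax import nnx


class CrossAttentionWithKVCache(nnx.Module):

  def __init__(self, dim: int, num_heads: int, rngs: nnx.Rngs):
    self.num_heads = num_heads
    self.head_dim = dim // num_heads
    self.to_q = nnx.Linear(dim, dim, rngs=rngs)
    self.to_k = nnx.Linear(dim, dim, rngs=rngs)
    self.to_v = nnx.Linear(dim, dim, rngs=rngs)
    self.to_out = nnx.Linear(dim, dim, rngs=rngs)

  def compute_kv(self, encoder_hidden_states):
    """Pre-project the text/image states once; they stay constant across denoising steps."""
    b, s, _ = encoder_hidden_states.shape
    key = self.to_k(encoder_hidden_states).reshape(b, s, self.num_heads, self.head_dim)
    value = self.to_v(encoder_hidden_states).reshape(b, s, self.num_heads, self.head_dim)
    return key, value

  def __call__(self, hidden_states, encoder_hidden_states, cached_kv=None):
    b, n, _ = hidden_states.shape
    query = self.to_q(hidden_states).reshape(b, n, self.num_heads, self.head_dim)
    if cached_kv is not None:
      key, value = cached_kv  # reuse the pre-computed projections, skip to_k/to_v
    else:
      key, value = self.compute_kv(encoder_hidden_states)
    out = jax.nn.dot_product_attention(query, key, value)  # (b, n, heads, head_dim)
    return self.to_out(out.reshape(b, n, -1))
```

The intent is that the pipeline calls compute_kv once per prompt and then passes the result as cached_kv at every denoising step.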

2. Dynamic Image Alignment Padding (Trillium & Ironwood Optimization)

  • Problem: Image embeddings were previously hardcoded to pad to multiples of 128. That is optimal for the MXU tile sizes of older TPUs (v4, v5p/v5e), but next-generation hardware such as Trillium (v6e) and Ironwood (v7x) uses larger $256 \times 256$ MXU tile structures.
  • Solution: Replaced the hardcoded values with dynamic TPU hardware detection (get_tpu_type()). Both attention_flax.py and embeddings_flax.py (NNXWanImageEmbedding) now adjust the image alignment padding dynamically (see the sketch after this list):
    • 256-alignment on Trillium (v6e) and Ironwood (v7x) to match the larger hardware tiles.
    • 128-alignment fallback on older TPU architectures (v5p and below).
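
The padding decision itself can be pictured with a rough, self-contained sketch. The PR's real helpers are get_tpu_type() and TpuType inside maxdiffusion; the device_kind substrings and the helper names below are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def image_seq_alignment() -> int:
  """Pick the image-embedding sequence alignment from the TPU generation.

  Hypothetical stand-in for the PR's get_tpu_type(); the device_kind substrings
  are illustrative, not an exhaustive mapping.
  """
  kind = jax.devices()[0].device_kind.lower()
  if "v6e" in kind or "v7" in kind:  # Trillium / Ironwood: larger 256x256 MXU tiles
    return 256
  return 128                         # v5p and earlier: 128x128 MXU tiles


def pad_image_embeds(image_embeds: jnp.ndarray) -> jnp.ndarray:
  """Zero-pad the image-embedding sequence length up to the chosen alignment."""
  align = image_seq_alignment()
  seq_len = image_embeds.shape[1]
  padded_len = ((seq_len + align - 1) // align) * align
  return jnp.pad(image_embeds, ((0, 0), (0, padded_len - seq_len), (0, 0)))
```

Rounding the padded sequence length up to the tile width keeps the downstream matmuls on full MXU tiles rather than partially filled ones.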

3. Critical Bug Fixes

  • Fixed prompt=None handling in Pipeline: Updated encode_prompt and __call__ in wan_pipeline.py to infer batch_size from prompt_embeds when prompt is None, unblocking tests that pass pre-computed embeddings without raw text prompts.
  • Fixed getattr with defaults in pyconfig: Changed __getattr__ in the HyperParameters class in pyconfig.py to raise AttributeError instead of ValueError when a key is not found. This allows standard Python getattr(config, "key", default) to correctly return the default value instead of crashing (illustrated after this list).
  • Replaced the deprecated jax.tree_map with jax.tree.map in _cast_floating_to.
  • Fixed typo in error message ("T2Vin" -> "T2V in").
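
The pyconfig fix boils down to a standard Python contract, shown here with a toy stand-in (not the actual HyperParameters class): getattr(obj, name, default) only falls back to the default when attribute lookup raises AttributeError; any other exception from __getattr__ propagates to the caller.

```python
class HyperParametersSketch:
  """Toy stand-in for pyconfig.HyperParameters, illustrating the getattr contract."""

  def __init__(self, values):
    self._values = dict(values)

  def __getattr__(self, name):
    try:
      return self._values[name]
    except KeyError:
      # Raising ValueError here would crash getattr(config, "key", default);
      # AttributeError lets the default kick in, as standard Python expects.
      raise AttributeError(f"Config has no attribute {name!r}") from None


config = HyperParametersSketch({"use_kv_cache": False})
print(getattr(config, "use_kv_cache", True))       # False  (found in the config)
print(getattr(config, "missing_key", "fallback"))  # fallback  (AttributeError -> default)
```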

Detailed File Changes

Models

  • attention_flax.py:
    • Imported get_tpu_type and TpuType for dynamic hardware-aware image alignment padding.
    • Integrated cached_kv routing inside FlaxWanAttention.__call__.
    • Implemented compute_kv support for both T2V and I2V cross-attentions.
  • transformer_wan.py:
    • Added skip_embeddings parameter inside WanTimeTextImageEmbedding to bypass redundant text/image projections.
    • Updated WanTransformerBlock and WanModel to handle cached_kv / kv_cache passing.
    • Implemented WanModel.compute_kv_cache to precompute block-level cached keys/values across scan and non-scan layers (sketched after this list).
  • transformer_wan_vace.py:
    • Added compute_kv to WanVACETransformerBlock.
    • Added compute_kv_cache to WanVACEModel to precompute block-level cached keys/values for both VACE and main branches.
    • Updated __call__ to accept and propagate kv_cache to both VACE and main blocks.
  • modeling_flax_utils.py:
    • Replaced the deprecated jax.tree_map with jax.tree.map in _cast_floating_to.
  • embeddings_flax.py:
    • Updated NNXWanImageEmbedding to dynamically align to 256 for v6e/v7x and 128 otherwise, avoiding shape mismatches during cross-attention.
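
Building on the attention sketch above, the block-level pre-computation can be pictured with two hypothetical helpers (not the real WanModel.compute_kv_cache, which also handles scanned layers, image states, and sharding):

```python
# Hypothetical helpers over a list of cross-attention blocks (see the sketch above).
def compute_kv_cache(blocks, encoder_hidden_states):
  """One (key, value) pair per block, computed once before the denoising loop."""
  return [block.compute_kv(encoder_hidden_states) for block in blocks]


def transformer_forward(blocks, hidden_states, encoder_hidden_states, kv_cache=None):
  """Run all blocks, reusing each block's cached K/V whenever a cache is supplied."""
  for i, block in enumerate(blocks):
    cached_kv = kv_cache[i] if kv_cache is not None else None
    hidden_states = block(hidden_states, encoder_hidden_states, cached_kv=cached_kv)
  return hidden_states
```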

Pipelines

  • wan_pipeline.py:
    • Updated transformer_forward_pass, transformer_forward_pass_full_cfg, and transformer_forward_pass_cfg_cache to accept and pass kv_cache.
    • Fixed NoneType error when calculating batch size by checking if prompt is None.
  • wan_vace_pipeline_2_1.py:
    • Added use_kv_cache parameter to pipeline calls.
    • Pre-computed kv_cache before the denoising loop and passed it to transformer_forward_pass in run_inference.
  • wan_pipeline_2_1.py & wan_pipeline_2_2.py:
    • Added a use_kv_cache parameter to the pipeline calls and pre-computed kv_cache and rotary_emb prior to the denoising loop (a toy version of this loop appears after this list).
  • wan_pipeline_i2v_2p1.py & wan_pipeline_i2v_2p2.py:
    • Integrated dynamic pre-computed kv_cache support for I2V workflows.
  • generate_wan.py:
    • Fixed typo in error message ("T2Vin" -> "T2V in").
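
At the pipeline level, the shape of the change is roughly the toy loop below, which ties the earlier sketches together; the shapes, the fixed step count, and the latent update are placeholders (the real helpers are jitted and also handle CFG, scan mode, and sharding).

```python
# Toy denoising loop reusing the sketches above (CrossAttentionWithKVCache,
# compute_kv_cache, transformer_forward); all shapes and the update rule are dummies.
import jax.numpy as jnp
from flax import nnx

dim, num_heads, num_layers = 128, 8, 2
rngs = nnx.Rngs(0)
blocks = [CrossAttentionWithKVCache(dim, num_heads, rngs) for _ in range(num_layers)]

latents = jnp.zeros((1, 1024, dim))        # stand-in for the flattened video latents
prompt_embeds = jnp.zeros((1, 512, dim))   # stand-in for the encoded text prompt

use_kv_cache = True
# The text K/V projections are constant across steps, so compute them once up front.
kv_cache = compute_kv_cache(blocks, prompt_embeds) if use_kv_cache else None

for _ in range(4):                          # stand-in for the scheduler timesteps
  noise_pred = transformer_forward(blocks, latents, prompt_embeds, kv_cache=kv_cache)
  latents = latents - 0.1 * noise_pred      # placeholder for scheduler.step(...)
```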

Configs

  • Configs (base_wan_1_3b.yml, base_wan_14b.yml, base_wan_27b.yml, base_wan_i2v_14b.yml, base_wan_i2v_27b.yml):
    • Added use_kv_cache: False for all default configs.
    • Modified the prompt in the I2V config files.

Performance Note

  • Observed Latency Savings:
    • ~0.2s on TPU v7x-8 (Ironwood)
    • ~0.5s on TPU v6e-8 (Trillium)
  • Analysis: The latency savings over a full denoising run are minimal, and this is expected: the cross-attention Key/Value projections operate on a very short text-prompt sequence (typically 512 tokens), so the FLOPs saved by caching them are a negligible fraction (< 0.01%) of the total workload, which is dominated by the self-attention and FFN layers processing the much longer latent sequence at every step of the denoising loop (a rough back-of-envelope estimate follows).
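
The magnitudes can be checked with a back-of-envelope estimate; the hidden size, FFN width, and token counts below are illustrative guesses (roughly 14B-model / 720p / 81-frame scale), so the exact fraction varies with resolution, frame count, and model size.

```python
# Back-of-envelope FLOP comparison: cached text K/V projections vs. one transformer
# layer at one denoising step. All numbers are illustrative assumptions.
hidden = 5120           # assumed transformer width
ffn = 13824             # assumed FFN width
text_tokens = 512       # padded text-prompt length
latent_tokens = 75_000  # rough flattened video-latent length at 720p / 81 frames

# Saved per layer per step: the K and V projections over the text prompt.
saved = 2 * 2 * text_tokens * hidden * hidden

# Kept per layer per step: self-attention QKVO projections, attention matmuls, FFN,
# plus cross-attention Q/O projections and its (much smaller) attention matmuls.
self_proj = 4 * 2 * latent_tokens * hidden * hidden
self_attn = 2 * 2 * latent_tokens * latent_tokens * hidden
ffn_cost = 2 * 2 * latent_tokens * hidden * ffn
cross = 2 * 2 * latent_tokens * hidden * hidden + 2 * 2 * latent_tokens * text_tokens * hidden

total = self_proj + self_attn + ffn_cost + cross
print(f"saved fraction per step ~ {saved / total:.1e}")  # a small fraction of a percent
```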

Note

Test Configuration: 720p | 81 frames | 40 steps
Hardware: TPU v7x-8
JAX Version: v0.10.0

| Model Variant | Baseline Generation Time | Current Generation Time | Video Link |
| --- | --- | --- | --- |
| WAN2.2 T2V | 132.4s | 132.2s | Link |
| WAN2.2 I2V | 133.6s | 133.3s | Link |
| WAN2.1 T2V | 132.2s | 132.1s | Link |
| WAN2.1 I2V | 142.8s | 142.7s | Link |

Conclusion: No visual change from baseline across all tested variants.

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 6, 2026 10:21

@Perseus14 Perseus14 requested review from mbohlool and prishajain1 May 6, 2026 10:30
@Perseus14 Perseus14 self-assigned this May 7, 2026

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


github-actions Bot commented May 7, 2026

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces significant performance optimizations for WAN models, specifically KV caching and hardware-aware dynamic image alignment padding for next-gen TPUs. The implementation of pre-computed RoPE and KV caches outside the denoising loop is an excellent architectural improvement.

🔍 General Feedback

  • KV Cache Coverage: While the KV cache and RoPE pre-computation are implemented for all models, they are currently missing from the scan_diffusion_loop paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines. This prevents users of the performance-oriented scan mode from benefiting from these optimizations. I have provided suggestions to propagate these caches into the scan bodies.
  • Hardware-Aware Padding: The dynamic adjustment of image alignment padding (256 vs 128) based on TPU type is correctly implemented across the attention and embedding layers.
  • Code Organization: Moving the concatenation of embeds and RoPE computation outside the denoising loop significantly reduces redundant computation at every step.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request implements important performance optimizations for WAN models, including KV caching and hardware-aware dynamic image alignment padding for TPU v6e and v7x. The pre-computation of RoPE and KV caches is a significant improvement for inference efficiency.

🔍 General Feedback

  • Missing Scan Loop Updates: While the optimizations are correctly implemented for the standard and CFG-cache denoising loops, the scan_diffusion_loop paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines have not been updated to use the new pre-computed KV caches and RoPE. This means that users utilizing the scan mode will not see the expected performance gains and will still incur the cost of redundant RoPE computations at each step.
  • Hardware-Awareness: The dynamic padding logic (256 for Trillium/Ironwood, 128 otherwise) is well-integrated into the attention and embedding modules.
  • Consistency: I2V 2.2 was correctly updated to handle these changes in its scan loop, but the other pipelines were missed. Synchronizing these would ensure a consistent performance profile across all WAN model variants.

Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully implements KV Cache optimization for all WAN models (2.1 & 2.2, T2V & I2V), providing significant inference speedups on TPU. It also introduces dynamic image alignment padding, optimizing performance for Trillium (v6e) and Ironwood (v7x) accelerators. The overall implementation is solid and well-integrated into the existing pipeline and model structures.

🔍 General Feedback

  • KV Cache Integration: The propagation of the KV cache from the pipeline through the transformer and attention layers is handled correctly, including support for complex CFG and caching strategies.
  • Trillium Optimizations: Dynamically adjusting padding to 256 for newer TPU generations is a great performance win.
  • Critical Path Bug: Identified a NameError in the WanModel's TI2V path (per_token_t=True) that needs immediate attention.
  • Consistency: Minor inconsistencies in RoPE dummy shape calculations across pipelines were noted but do not affect correctness.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This second part of the review adds specific inline comments and suggestions for the core model and transformer changes.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.
  • Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in FlaxWanAttention.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

Detailed inline comments and suggestions for model improvements and bug fixes.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.
  • Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in FlaxWanAttention.

Comment thread src/maxdiffusion/models/attention_flax.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py

@github-actions github-actions Bot left a comment


## 📋 Review Summary

Final part of the review with the critical fix for the TI2V path.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.

@Perseus14 Perseus14 force-pushed the wan_kv_cache branch 3 times, most recently from ec2f6bb to 8467dbb Compare May 7, 2026 19:02


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures. The changes are largely consistent across the main model variants and pipelines, but there are some omissions in the Animate model and inconsistencies in parameter propagation that should be addressed.

🔍 General Feedback

  • Optimization: The pre-computation of RoPE and KV cache is a welcome improvement that aligns with JAX/Flax best practices for efficient inference.
  • Hardware-Awareness: The use of dynamic alignment (128 vs 256) based on TPU type is a great addition for performance portability across Trillium and earlier generations.
  • Omission: WanAnimateTransformer3DModel appears to have been missed in this update, despite the PR's intent to cover all WAN models.
  • Consistency: Ensure all WanTransformerBlock instantiations (especially in non-scan paths) propagate the new parameters (use_base2_exp, etc.) to avoid behavioral discrepancies.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures (Trillium/Ironwood). While the core implementation is solid, there are omissions in the Animate model and inconsistencies in parameter propagation across non-scan code paths.

🔍 General Feedback

  • Optimization: The pre-computation of RoPE and KV cache is a significant performance improvement for long sequence generation.
  • Hardware-Awareness: The dynamic alignment logic based on TPU type is excellent for ensuring optimal MXU utilization across different hardware generations.
  • Omission: WanAnimateTransformer3DModel needs to be updated with KV cache support to fulfill the goal of supporting "all WAN models".
  • Consistency: Ensure all WanTransformerBlock instantiations propagate use_base2_exp and use_experimental_scheduler to maintain behavioral parity between scan and non-scan paths.

Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan.py


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request implements KV caching for WAN models and adds dynamic image alignment padding for next-gen TPUs. The implementation is technically sound, significantly improves inference efficiency by avoiding redundant text embedding projections, and includes a comprehensive unit test for verification.

🔍 General Feedback

  • KV Caching Logic: The addition of compute_kv_cache to WanModel and its integration into the denoising loops is well-executed.
  • TPU Optimization: The dynamic alignment padding (up to 256 for TPU v6e/v7) is a critical optimization for performance on new hardware.
  • Logging: Recommend replacing standard print statements with max_logging to maintain consistency across the codebase.
  • Maintainability: The fix in pyconfig.py correctly aligns with Python's expected behavior for __getattr__.

Comment thread src/maxdiffusion/models/embeddings_flax.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This review includes additional feedback on logging practices and consistency across the new caching implementations.

🔍 General Feedback

  • Consistency: Replaced remaining print statements with max_logging to ensure all diagnostic output follows the project's standard logging infrastructure.

@Perseus14 Perseus14 force-pushed the wan_kv_cache branch 2 times, most recently from 960afef to 99a0eb0 Compare May 10, 2026 12:40

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR successfully implements KV Cache optimization for all WAN models and the WAN VACE model, significantly reducing redundant computations during the denoising loop. It also introduces hardware-aware Dynamic Image Alignment Padding to optimize performance on next-generation TPU architectures (v6e and v7x).

🔍 General Feedback

  • Efficiency: The KV cache optimization is well-integrated across all pipelines (T2V, I2V, VACE) and models. While the latency savings are modest for shorter prompts, it's a solid architectural improvement.
  • Hardware Optimization: The dynamic TPU type detection for image alignment padding is a great addition for future-proofing the codebase.
  • Robustness: The fix for prompt=None in pipelines and the AttributeError fix in pyconfig improve the overall stability and usability of the CLI.
  • Parity: The new wan_kv_cache_test.py verifies that the optimization does not introduce regressions in output quality.

Comment thread src/maxdiffusion/models/modeling_flax_utils.py
Comment on lines +777 to +805
```python
    timestep_proj,
    encoder_hidden_states_out,
    encoder_hidden_states_image_out,
    encoder_attention_mask_out,
) = self.condition_embedder(
    timestep,
    encoder_hidden_states,
    encoder_hidden_states_image,
    encoder_attention_mask,
    skip_embeddings=(kv_cache is not None),
)
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

if encoder_attention_mask is None:
  encoder_attention_mask = encoder_attention_mask_out

if encoder_hidden_states_image_out is not None:
  # Workaround: When kv_cache is used, skip_embeddings=True causes raw image (1280)
  # and text (4096) embeddings to have different feature dimensions.
  # We pad them with zeros to match dimensions so they can be concatenated.
  # This maintains the correct total sequence length for FlaxWanAttention's internal slicing.
  # Since these values are ignored when cached_kv is present, the content doesn't matter.
  if kv_cache is not None and encoder_hidden_states_image_out.shape[-1] != encoder_hidden_states_out.shape[-1]:
    img_dim = encoder_hidden_states_image_out.shape[-1]
    text_dim = encoder_hidden_states_out.shape[-1]
    if img_dim < text_dim:
      pad_shape = (
          encoder_hidden_states_image_out.shape[0],
          encoder_hidden_states_image_out.shape[1],
          text_dim - img_dim,
```

🟡 This workaround for concatenating raw embeddings of different dimensions is functional but slightly brittle. While it's true that these values are ignored when `cached_kv` is present, it might be cleaner to handle the padding logic within `WanTimeTextImageEmbedding` or ensure that `condition_embedder` returns compatible shapes even when skipping projections.
Suggested change
```python
if kv_cache is not None and encoder_hidden_states_image_out.shape[-1] != encoder_hidden_states_out.shape[-1]:
  img_dim = encoder_hidden_states_image_out.shape[-1]
  text_dim = encoder_hidden_states_out.shape[-1]
  if img_dim < text_dim:
    pad_shape = (
        encoder_hidden_states_image_out.shape[0],
        encoder_hidden_states_image_out.shape[1],
        text_dim - img_dim,
    )
    encoder_hidden_states_image_out = jnp.concatenate(
        [
            encoder_hidden_states_image_out,
            jnp.zeros(pad_shape, dtype=encoder_hidden_states_image_out.dtype),
        ],
        axis=-1,
    )
  else:
    pad_shape = (
        encoder_hidden_states_out.shape[0],
        encoder_hidden_states_out.shape[1],
        img_dim - text_dim,
    )
    encoder_hidden_states_out = jnp.concatenate(
        [encoder_hidden_states_out, jnp.zeros(pad_shape, dtype=encoder_hidden_states_out.dtype)], axis=-1
    )
```

Comment thread src/maxdiffusion/pyconfig.py
Comment thread src/maxdiffusion/generate_wan.py
Comment thread src/maxdiffusion/tests/wan_kv_cache_test.py